import warnings
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

def one_hot(
    labels: torch.Tensor,
    num_classes: int,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    r"""Convert an integer label x-D tensor to a one-hot (x+1)-D tensor.

    Args:
        labels: tensor with labels of shape :math:`(N, *)`, where N is batch size.
          Each value is an integer representing correct classification.
        num_classes: number of classes in labels.
        device: the desired device of returned tensor. When ``None`` the
          device of ``labels`` is used.
        dtype: the desired data type of returned tensor.
        eps: constant added to every element of the encoded tensor, so the
          returned values are ``1 + eps`` at the labelled class and ``eps``
          everywhere else.

    Returns:
        the labels in one hot tensor of shape :math:`(N, C, *)`.

    Raises:
        TypeError: if ``labels`` is not a ``torch.Tensor``.
        ValueError: if ``labels`` is not of dtype ``torch.int64`` or if
          ``num_classes`` is smaller than one.
    """
    if not isinstance(labels, torch.Tensor):
        raise TypeError(f"Input labels type is not a torch.Tensor. Got {type(labels)}")

    if labels.dtype != torch.int64:
        raise ValueError(f"labels must be of the same dtype torch.int64. Got: {labels.dtype}")

    if num_classes < 1:
        raise ValueError(f"The number of classes must be at least one. Got: {num_classes}")

    # Allocate next to the labels unless the caller asked for a specific
    # device; scatter_ requires the index and the destination to share one.
    if device is None:
        device = labels.device

    shape = labels.shape
    encoded = torch.zeros((shape[0], num_classes) + shape[1:], device=device, dtype=dtype)

    # Insert the class axis at dim 1 and write a 1 at each labelled class.
    index = labels.unsqueeze(1).to(device)
    return encoded.scatter_(1, index, 1.0) + eps


class FocalLoss(nn.Module):
    # Taken from https://github.com/kornia/kornia/blob/master/kornia/losses/focal.py
    r"""Criterion that computes Focal loss.

    According to :cite:`lin2018focal`, the Focal loss is computed as follows:

    .. math::
        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)

    Where:
       - :math:`p_t` is the model's estimated probability for the target class.

    Args:
        alpha: Optional tensor of per-class weighting factors. It is indexed
          with the target labels, so it needs one entry per class. ``None``
          disables the weighting.
        gamma: Focusing parameter :math:`\gamma >= 0`.
        reduction: Specifies the reduction to apply to the
          output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
          will be applied, ``'mean'``: the sum of the output will be divided by
          the number of elements in the output (or by the sum of the applied
          ``alpha`` weights when ``alpha`` is given), ``'sum'``: the output
          will be summed.
        eps: Deprecated: scalar to enforce numerical stability. This is no longer
          used.

    Shape:
        - Input: :math:`(N, C, *)` where C = number of classes.
        - Target: :math:`(N, *)` where each value is
          :math:`0 ≤ targets[i] ≤ C−1`.
    """

    def __init__(
        self,
        alpha: Optional[torch.Tensor] = None,
        gamma: float = 0.0,
        reduction: str = 'mean',
        eps: Optional[float] = None,
    ) -> None:
        super().__init__()
        self.alpha: Optional[torch.Tensor] = alpha
        self.gamma: float = gamma
        self.reduction: str = reduction
        self.eps: Optional[float] = eps

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the focal loss using the settings stored on the module."""
        return self.focal_loss(input, target, self.alpha, self.gamma, self.reduction, self.eps)

    def focal_loss(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        alpha: Optional[torch.Tensor],
        gamma: float = 2.0,
        reduction: str = 'none',
        eps: Optional[float] = None,
    ) -> torch.Tensor:
        # Taken from https://github.com/kornia/kornia/blob/master/kornia/losses/focal.py
        r"""Compute the focal loss between logits and integer class labels.

        According to :cite:`lin2018focal`, the Focal loss is computed as follows:

        .. math::
            \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)

        Where:
           - :math:`p_t` is the model's estimated probability for the target class.

        Args:
            input: logits tensor with shape :math:`(N, C, *)` where C = number of classes.
            target: labels tensor with shape :math:`(N, *)` where each value is :math:`0 ≤ targets[i] ≤ C−1`.
            alpha: Optional tensor of per-class weighting factors, indexed with
              ``target``. ``None`` disables the weighting. The weights are
              applied for every reduction mode.
            gamma: Focusing parameter :math:`\gamma >= 0`.
            reduction: Specifies the reduction to apply to the
              output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
              will be applied, ``'mean'``: the sum of the output will be divided by
              the number of elements in the output (or by the sum of the
              applied ``alpha`` weights when ``alpha`` is given), ``'sum'``:
              the output will be summed.
            eps: Deprecated: scalar to enforce numerical stability. This is no longer used.

        Return:
            the computed loss.

        Raises:
            TypeError: if ``input`` or ``target`` is not a ``torch.Tensor``.
            ValueError: if the shapes, batch sizes or devices of ``input`` and
              ``target`` are inconsistent, or ``target`` is not ``torch.int64``.
            NotImplementedError: if ``reduction`` is not a supported mode.
        """
        if eps is not None and not torch.jit.is_scripting():
            warnings.warn(
                "`focal_loss` has been reworked for improved numerical stability "
                "and the `eps` argument is no longer necessary",
                DeprecationWarning,
                stacklevel=2,
            )

        if not isinstance(input, torch.Tensor):
            raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

        if not isinstance(target, torch.Tensor):
            raise TypeError(f"Input labels type is not a torch.Tensor. Got {type(target)}")

        if not len(input.shape) >= 2:
            raise ValueError(f"Invalid input shape, we expect BxCx*. Got: {input.shape}")

        if input.size(0) != target.size(0):
            raise ValueError(
                f'Expected input batch_size ({input.size(0)}) to match target batch_size ({target.size(0)}).')

        n = input.size(0)
        out_size = (n,) + input.size()[2:]
        if target.size()[1:] != input.size()[2:]:
            raise ValueError(f'Expected target size {out_size}, got {target.size()}')

        if not input.device == target.device:
            raise ValueError(f"input and target must be in the same device. Got: {input.device} and {target.device}")

        if target.dtype != torch.int64:
            raise ValueError(f"labels must be of the same dtype torch.int64. Got: {target.dtype}")

        # Log-probabilities over the class axis; pick out the target class
        # directly instead of multiplying by a dense one-hot mask, so that no
        # contribution from the non-target classes leaks into the loss.
        log_probs = F.log_softmax(input, dim=1)
        log_pt = log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()

        # (1 - p_t)^gamma down-weights the examples that are already well classified.
        loss_tmp = -torch.pow(1.0 - pt, gamma) * log_pt

        # Per-element class weights, looked up on the same device as the data.
        class_weight: Optional[torch.Tensor] = None
        if alpha is not None:
            class_weight = alpha.to(input.device)[target]
            loss_tmp = class_weight * loss_tmp

        if reduction == 'none':
            return loss_tmp
        if reduction == 'mean':
            if class_weight is not None:
                # Weighted mean: normalise by the total weight that was applied.
                return torch.sum(loss_tmp) / class_weight.sum()
            return torch.mean(loss_tmp)
        if reduction == 'sum':
            return torch.sum(loss_tmp)
        raise NotImplementedError(f"Invalid reduction mode: {reduction}")